#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>

using namespace std;

// Value printed when no run of m elements exists (m > n). Kept equal to the
// "not found" value the program has always printed in that situation.
static const long long kNoAnswer = 0x3f3f3f;

// Sorts `a` ascending, then returns the minimum, over every run of m
// consecutive (sorted) elements, of the sum of |a[i+1]^2 - a[i]^2| across the
// adjacent pairs in that run.
//   a - the input values (taken by value; sorted locally)
//   m - number of consecutive elements per run
// Returns 0 when m <= 1 (a single element has no adjacent pairs) and
// kNoAnswer when m exceeds the number of values.
static long long min_square_gap_sum(vector<long long> a, int m)
{
	sort(a.begin(), a.end());
	const int n = static_cast<int>(a.size());
	if (m <= 1)
		return 0;
	if (m > n)
		return kNoAnswer;

	// gap[i] = |a[i+1]^2 - a[i]^2| for i in [0, n-2]. There are only n-1 gaps;
	// squares are formed in 64-bit so large inputs cannot overflow.
	vector<long long> gap(n - 1);
	for (int i = 0; i + 1 < n; i++)
	{
		const long long d = a[i + 1] * a[i + 1] - a[i] * a[i];
		gap[i] = d < 0 ? -d : d;
	}

	// A run of m elements spans exactly m-1 consecutive gaps: slide a window
	// of that fixed width across the gaps and keep the smallest sum.
	const int width = m - 1;
	long long sum = 0;
	for (int i = 0; i < width; i++)
		sum += gap[i];
	long long best = sum;
	for (int right = width; right < n - 1; right++)
	{
		sum += gap[right] - gap[right - width];
		best = min(best, sum);
	}
	return best;
}

// Reads n and m followed by n integers from stdin and prints the minimum
// squared-gap sum over any m consecutive sorted values.
int main()
{
	int n = 0, m = 0;
	cin >> n >> m;
	// Guard against a negative count so vector construction cannot throw.
	vector<long long> a(max(n, 0));
	for (size_t i = 0; i < a.size(); i++)
		cin >> a[i];
	cout << min_square_gap_sum(a, m) << '\n';
	return 0;
}
